package io.unitycatalog.server.service;

import static io.unitycatalog.server.model.SecurableType.CATALOG;
import static io.unitycatalog.server.model.SecurableType.METASTORE;
import static io.unitycatalog.server.model.SecurableType.REGISTERED_MODEL;
import static io.unitycatalog.server.model.SecurableType.SCHEMA;
import static io.unitycatalog.server.service.credential.CredentialContext.Privilege.SELECT;
import static io.unitycatalog.server.service.credential.CredentialContext.Privilege.UPDATE;

import com.linecorp.armeria.common.HttpResponse;
import com.linecorp.armeria.server.annotation.ExceptionHandler;
import com.linecorp.armeria.server.annotation.Post;
import io.unitycatalog.server.auth.UnityCatalogAuthorizer;
import io.unitycatalog.server.auth.decorator.KeyMapper;
import io.unitycatalog.server.auth.decorator.UnityAccessEvaluator;
import io.unitycatalog.server.exception.BaseException;
import io.unitycatalog.server.exception.ErrorCode;
import io.unitycatalog.server.exception.GlobalExceptionHandler;
import io.unitycatalog.server.model.GenerateTemporaryModelVersionCredential;
import io.unitycatalog.server.model.ModelVersionInfo;
import io.unitycatalog.server.model.ModelVersionOperation;
import io.unitycatalog.server.model.ModelVersionStatus;
import io.unitycatalog.server.model.SecurableType;
import io.unitycatalog.server.persist.ModelRepository;
import io.unitycatalog.server.persist.Repositories;
import io.unitycatalog.server.persist.UserRepository;
import io.unitycatalog.server.persist.utils.RepositoryUtils;
import io.unitycatalog.server.service.credential.CloudCredentialVendor;
import io.unitycatalog.server.service.credential.CredentialContext;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import lombok.SneakyThrows;

/**
 * Armeria service that vends temporary cloud storage credentials for the storage location of a
 * registered model version.
 *
 * <p>A request is first authorized against the catalog, schema and registered model it names. The
 * model version's storage location and registration status are then validated before credentials
 * are requested from the {@link CloudCredentialVendor}.
 */
@ExceptionHandler(GlobalExceptionHandler.class)
public class TemporaryModelVersionCredentialsService {

  /**
   * Expression checked for read-only requests. The caller needs OWNER or EXECUTE on the model,
   * plus OWNER or USE_SCHEMA on the schema and OWNER or USE_CATALOG on the catalog.
   */
  private static final String READ_EXPRESSION =
      """
      #authorizeAny(#principal, #registered_model, OWNER, EXECUTE) &&
      #authorizeAny(#principal, #schema, OWNER, USE_SCHEMA) &&
      #authorizeAny(#principal, #catalog, OWNER, USE_CATALOG)
      """;

  /**
   * Expression checked for every non-read request. The caller needs OWNER on the model, plus OWNER
   * or USE_SCHEMA on the schema and OWNER or USE_CATALOG on the catalog.
   */
  private static final String WRITE_EXPRESSION =
      """
      #authorize(#principal, #registered_model, OWNER) &&
      #authorizeAny(#principal, #schema, OWNER, USE_SCHEMA) &&
      #authorizeAny(#principal, #catalog, OWNER, USE_CATALOG)
      """;

  private final ModelRepository modelRepository;
  private final UserRepository userRepository;

  private final UnityAccessEvaluator evaluator;
  private final CloudCredentialVendor cloudCredentialVendor;
  private final KeyMapper keyMapper;

  /**
   * Creates the service.
   *
   * @param authorizer authorizer used to build the access-expression evaluator
   * @param cloudCredentialVendor vendor that mints the temporary credentials
   * @param repositories provides the model and user repositories and backs the resource key mapper
   */
  @SneakyThrows
  public TemporaryModelVersionCredentialsService(
      UnityCatalogAuthorizer authorizer,
      CloudCredentialVendor cloudCredentialVendor,
      Repositories repositories) {
    this.evaluator = new UnityAccessEvaluator(authorizer);
    this.cloudCredentialVendor = cloudCredentialVendor;
    this.keyMapper = new KeyMapper(repositories);
    this.modelRepository = repositories.getModelRepository();
    this.userRepository = repositories.getUserRepository();
  }

  /**
   * Generates temporary credentials scoped to a model version's storage location.
   *
   * @param generateTemporaryModelVersionCredentials request naming the catalog, schema, model,
   *     version and requested operation
   * @return JSON response containing the credentials returned by the vendor
   * @throws BaseException with {@code PERMISSION_DENIED} if the caller is not authorized.
   * @throws BaseException with {@code INVALID_ARGUMENT} in these cases:
   *     <ul>
   *       <li>the storage location is file based;
   *       <li>the version's status is {@code FAILED_REGISTRATION} or unknown;
   *       <li>read/write credentials are requested for a version that is not
   *           {@code PENDING_REGISTRATION};
   *       <li>the operation is missing or unknown.
   *     </ul>
   */
  @Post("")
  public HttpResponse generateTemporaryModelVersionCredentials(
      GenerateTemporaryModelVersionCredential generateTemporaryModelVersionCredentials) {
    authorizeForOperation(generateTemporaryModelVersionCredentials);

    long modelVersion = generateTemporaryModelVersionCredentials.getVersion();
    String catalogName = generateTemporaryModelVersionCredentials.getCatalogName();
    String schemaName = generateTemporaryModelVersionCredentials.getSchemaName();
    String modelName = generateTemporaryModelVersionCredentials.getModelName();
    String fullName = RepositoryUtils.getAssetFullName(catalogName, schemaName, modelName);

    ModelVersionInfo modelVersionInfo = modelRepository.getModelVersion(fullName, modelVersion);
    String storageLocation = modelVersionInfo.getStorageLocation();
    // Use Locale.ROOT because lowercasing with the default locale can change the letters.
    // In a Turkish locale "FILE" becomes "fıle" (dotless i), which would bypass this check.
    if (storageLocation.toLowerCase(Locale.ROOT).startsWith("file")) {
      String errorMsg = String.format(
          "Cannot request credentials on a model version with a file based storage location: %s/%d",
          fullName, modelVersion);
      throw new BaseException(ErrorCode.INVALID_ARGUMENT, errorMsg);
    }

    ModelVersionOperation requestedOperation =
        generateTemporaryModelVersionCredentials.getOperation();
    ModelVersionStatus status = modelVersionInfo.getStatus();
    // Must enforce that the status of the model version matches the requested credential type.
    if (status == ModelVersionStatus.FAILED_REGISTRATION
        || status == ModelVersionStatus.MODEL_VERSION_STATUS_UNKNOWN) {
      String errorMsg = String.format(
          "Cannot request credentials on a model version with status %s: %s/%d",
          status.getValue(), fullName, modelVersion);
      throw new BaseException(ErrorCode.INVALID_ARGUMENT, errorMsg);
    }
    // Write credentials are only handed out while the version is still pending registration.
    if (status != ModelVersionStatus.PENDING_REGISTRATION
        && requestedOperation == ModelVersionOperation.READ_WRITE_MODEL_VERSION) {
      String errorMsg = String.format(
          "Cannot request read/write credentials on a model version that has been finalized: %s/%d",
          fullName, modelVersion);
      throw new BaseException(ErrorCode.INVALID_ARGUMENT, errorMsg);
    }
    return HttpResponse.ofJson(
        cloudCredentialVendor.vendCredential(
            storageLocation, modelVersionOperationToPrivileges(requestedOperation)));
  }

  /**
   * Maps a requested model version operation to the storage privileges to vend.
   *
   * @param modelVersionOperation operation from the request; may be {@code null} if omitted
   * @return {@code SELECT} for reads, {@code SELECT} and {@code UPDATE} for read/write
   * @throws BaseException with {@code INVALID_ARGUMENT} if the operation is missing or unknown
   */
  private Set<CredentialContext.Privilege> modelVersionOperationToPrivileges(
      ModelVersionOperation modelVersionOperation) {
    // Without this check, switching on null would throw a NullPointerException.
    // That would reach the client as an internal error instead of a client error.
    if (modelVersionOperation == null) {
      throw new BaseException(ErrorCode.INVALID_ARGUMENT, "Operation is required in the request.");
    }
    return switch (modelVersionOperation) {
      case READ_MODEL_VERSION -> Set.of(SELECT);
      case READ_WRITE_MODEL_VERSION -> Set.of(SELECT, UPDATE);
      case UNKNOWN_MODEL_VERSION_OPERATION -> throw new BaseException(ErrorCode.INVALID_ARGUMENT,
          "Unknown operation in the request: " + ModelVersionOperation.UNKNOWN_MODEL_VERSION_OPERATION);
    };
  }

  /**
   * Checks that the current principal may perform the requested operation on the named model.
   *
   * <p>Read requests are checked against {@link #READ_EXPRESSION}. All other operations, including
   * missing or unknown ones, fall back to the stricter {@link #WRITE_EXPRESSION}.
   *
   * @param generateTemporaryModelVersionCredentials request naming the resources and operation
   * @throws BaseException with {@code PERMISSION_DENIED} if the expression evaluates to false
   */
  private void authorizeForOperation(
      GenerateTemporaryModelVersionCredential generateTemporaryModelVersionCredentials) {
    // TODO: This is a short term solution to conditional expression evaluation based on additional
    // request parameters. This should be replaced with more direct annotations and syntax in the
    // future.
    String authorizeExpression =
        generateTemporaryModelVersionCredentials.getOperation()
                == ModelVersionOperation.READ_MODEL_VERSION
            ? READ_EXPRESSION
            : WRITE_EXPRESSION;

    Map<SecurableType, Object> resourceKeys = keyMapper.mapResourceKeys(
        Map.of(METASTORE, "metastore",
            CATALOG, generateTemporaryModelVersionCredentials.getCatalogName(),
            SCHEMA, generateTemporaryModelVersionCredentials.getSchemaName(),
            REGISTERED_MODEL, generateTemporaryModelVersionCredentials.getModelName()));

    if (!evaluator.evaluate(userRepository.findPrincipalId(), authorizeExpression, resourceKeys)) {
      throw new BaseException(ErrorCode.PERMISSION_DENIED, "Access denied.");
    }
  }
}
